#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import math

from .base_scheduler import BaseSchedulerPlugin


class CosineSchedulerPlugin(BaseSchedulerPlugin):
    """余弦退火学习率调度器插件"""
    
    def get_lr(self, current_step, total_steps, base_lr):
        """
        使用余弦退火公式计算当前学习率，与MiniMind保持完全一致
        
        Args:
            current_step (int): 当前步数
            total_steps (int): 总步数
            base_lr (float): 基础学习率
            
        Returns:
            float: 当前学习率
        """
        # 与MiniMind完全一致的实现
        return base_lr / 10 + 0.5 * base_lr * (1 + math.cos(math.pi * current_step / total_steps))